
Add Qwen3.5 export support for 0.8B/2B/4B #17800

Merged
lucylq merged 15 commits into pytorch:main from Phineas1500:qwen3_5_phase1 on Mar 6, 2026

Conversation

@Phineas1500
Contributor

@Phineas1500 Phineas1500 commented Mar 2, 2026

Summary

  • add initial Qwen3.5 model integration under examples/models/qwen3_5
  • add HF->Meta checkpoint conversion for Qwen3.5 naming/layout
  • wire Qwen3.5 model types into llama export config and registry
  • add 0.8B, 2B, and 4B configs plus xnnpack fp32 export config
  • add unit tests for Qwen3.5 attention path and weight conversion

Test Plan

  • PYTHONPATH=src pytest -q examples/models/llama/tests/test_qwen3_5_attention.py examples/models/qwen3_5/tests/test_convert_weights.py
  • local smoke export succeeded for Qwen3.5-0.8B: /tmp/qwen35_test/qwen3_5_0_8b_fp32_smoke.pte

cc @mergennachin @iseeyuan @lucylq @helunwencser @tarun292 @kimishpatel @jackzhxng

Copilot AI review requested due to automatic review settings March 2, 2026 23:40
@pytorch-bot

pytorch-bot bot commented Mar 2, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/17800

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 Awaiting Approval, 6 New Failures, 1 Cancelled Job, 5 Pending, 1 Unrelated Failure

As of commit 14a425a with merge base 153adbf:

AWAITING APPROVAL - The following workflow needs approval before CI can run:

NEW FAILURES - The following jobs have failed:

CANCELLED JOB - The following job was cancelled. Please retry:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla

meta-cla bot commented Mar 2, 2026

Hi @Phineas1500!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!

@meta-cla

meta-cla bot commented Mar 2, 2026

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!

@meta-cla meta-cla bot added the "CLA Signed" label on Mar 2, 2026
@Phineas1500
Contributor Author

@pytorchbot label "release notes: none"

@pytorch-bot pytorch-bot bot added the "release notes: none" label on Mar 2, 2026
Contributor

Copilot AI left a comment


Pull request overview

This PR adds initial Qwen3.5 (0.8B/2B/4B) support to the ExecuTorch Llama export pipeline, including checkpoint conversion, model/attention wiring, and basic unit tests.

Changes:

  • Added Qwen3.5 model types to the export config/registry and HF auto-download/convert path.
  • Introduced Qwen3.5 weight conversion + model configs (0.8B/2B/4B) and an XNNPACK fp32 export YAML.
  • Implemented Qwen3.5 full-attention + gated DeltaNet linear-attention in the Llama attention stack, plus unit tests.

Reviewed changes

Copilot reviewed 20 out of 20 changed files in this pull request and generated 3 comments.

| File | Description |
| --- | --- |
| `extension/llm/export/config/llm_config.py` | Adds ModelType enum entries for Qwen3.5 sizes. |
| `examples/models/llama/export_llama_lib.py` | Registers Qwen3.5 model names, HF repo IDs, and hooks in Qwen3.5 weight conversion on download. |
| `examples/models/llama/attention.py` | Adds qwen3_5_full attention and gated_deltanet linear attention implementations; extends RMSNorm usage for q/k norm. |
| `examples/models/llama/model_args.py` | Adds Qwen3.5-specific linear-attention dimension args + rms_norm_add_unit_offset, and defaulting logic in `__post_init__`. |
| `examples/models/llama/norm.py` | Extends RMSNorm to optionally use (1 + weight) scaling. |
| `examples/models/llama/llama_transformer.py` | Wires rms_norm_add_unit_offset into norms and supports "linear_attention" layers via gated_deltanet. |
| `examples/models/llama/tests/test_qwen3_5_attention.py` | Adds unit tests for Qwen3.5 full-attention shape and DeltaNet state reset behavior. |
| `examples/models/llama/tests/BUCK` | Adds BUCK target for the new Qwen3.5 attention unit test. |
| `examples/models/llama/__init__.py` | Switches to lazy import pattern for Llama2Model. |
| `examples/models/qwen3_5/convert_weights.py` | Adds HF → meta checkpoint key conversion for Qwen3.5 (including legacy packed tensor splitting). |
| `examples/models/qwen3_5/config/qwen3_5_xnnpack_fp32.yaml` | Adds fp32 + static-shape XNNPACK export config for Qwen3.5. |
| `examples/models/qwen3_5/config/0_8b_config.json` | Adds 0.8B model parameter config (hybrid layer layout). |
| `examples/models/qwen3_5/config/2b_config.json` | Adds 2B model parameter config (hybrid layer layout). |
| `examples/models/qwen3_5/config/4b_config.json` | Adds 4B model parameter config (hybrid layer layout). |
| `examples/models/qwen3_5/tests/test_convert_weights.py` | Adds unit test validating key mapping for both full- and linear-attention weights. |
| `examples/models/qwen3_5/tests/__init__.py` | Adds package marker/license header for tests. |
| `examples/models/qwen3_5/__init__.py` | Adds lazy Qwen3_5Model wrapper + exports convert_weights. |
| `examples/models/qwen3_5/README.md` | Documents export + runner usage for Qwen3.5 sizes and current limitations. |
| `examples/models/qwen3_5/BUCK` | Adds BUCK target for the Qwen3.5 python library and deps. |
| `examples/models/BUCK` | Adds Qwen3.5 to the examples/models python_library deps list. |


```python
self.wo = nn.Linear(
    self.n_heads * self.head_dim,
    self.dim,
    bias=self.attention_qkv_bias,
```

Copilot AI Mar 2, 2026


AttentionQwen3_5Full sets the output projection (self.wo) bias based on args.attention_qkv_bias. In the rest of the llama attention implementations, the output projection is bias-free, and HF checkpoints typically don’t include an o_proj.bias. If attention_qkv_bias is enabled, this will introduce an extra bias parameter that won’t be loaded from the checkpoint (and will change numerics). Consider forcing wo to bias=False (or gating it on a dedicated flag rather than attention_qkv_bias).

Suggested change:

```diff
-            bias=self.attention_qkv_bias,
+            bias=False,
```

Comment on lines +823 to +839
```python
for i in range(sequence_length):
    q_t = query[:, :, i]
    k_t = key[:, :, i]
    v_t = value[:, :, i]
    g_t = g[:, :, i].exp().unsqueeze(-1).unsqueeze(-1)
    beta_t = beta[:, :, i].unsqueeze(-1)

    last_recurrent_state = last_recurrent_state * g_t
    kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2)
    delta = (v_t - kv_mem) * beta_t
    last_recurrent_state = (
        last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
    )
    core_attn_out[:, :, i] = (last_recurrent_state * q_t.unsqueeze(-1)).sum(
        dim=-2
    )
```


Copilot AI Mar 2, 2026


_recurrent_gated_delta_rule uses a Python for loop over sequence_length, which will be very slow for long prefill sequences and hard for compilers/export backends to optimize. Consider rewriting this recurrence using vectorized/scan-style tensor ops (or constraining this path to decode-only seq_len==1 and providing a separate efficient prefill implementation).

Suggested change (replace the loop with a decode-only step):

```python
# This recurrent implementation is intended for decode-only usage,
# where sequence_length == 1. Longer sequences should use a separate
# prefill implementation that can be efficiently vectorized.
if sequence_length != 1:
    raise NotImplementedError(
        "_recurrent_gated_delta_rule only supports decode-only "
        "(sequence_length == 1). Use a dedicated prefill path for "
        "longer sequences."
    )
# Unrolled single-step recurrence for sequence_length == 1.
q_t = query[:, :, 0]
k_t = key[:, :, 0]
v_t = value[:, :, 0]
g_t = g[:, :, 0].exp().unsqueeze(-1).unsqueeze(-1)
beta_t = beta[:, :, 0].unsqueeze(-1)
last_recurrent_state = last_recurrent_state * g_t
kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2)
delta = (v_t - kv_mem) * beta_t
last_recurrent_state = (
    last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
)
core_attn_out[:, :, 0] = (last_recurrent_state * q_t.unsqueeze(-1)).sum(
    dim=-2
)
```

```python
    raise ValueError(
        f"Invalid packed in_proj_qkvz shape for {key}: {tuple(value.shape)}"
    )
key_dim = (conv_dim - value_dim) // 2
```

Copilot AI Mar 2, 2026


key_dim is computed but never used, which makes the legacy in_proj_qkvz split logic harder to follow. Either remove it or use it to validate the packed tensor layout (e.g., sanity-check expected q/k/v sizing).

Suggested change:

```diff
-    key_dim = (conv_dim - value_dim) // 2
```
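If the validation route is taken instead, the arithmetic is small. A hedged sketch of the layout check (pure Python over row counts; the row interpretation — q and k each taking `key_dim` rows, v and z each taking `value_dim` rows — is an assumption inferred from the converter's split logic, not confirmed by the source):

```python
def packed_qkvz_dims(total_dim: int, value_dim: int) -> tuple[int, int]:
    """Decompose a packed in_proj_qkvz row count into conv and key dims.

    Assumed layout (hypothetical):
    total_dim = 2 * key_dim + value_dim (conv part) + value_dim (z part).
    """
    conv_dim = total_dim - value_dim
    if conv_dim <= 0 or (conv_dim - value_dim) % 2 != 0:
        raise ValueError(
            f"Invalid packed layout: total={total_dim}, value={value_dim}"
        )
    key_dim = (conv_dim - value_dim) // 2
    assert 2 * key_dim + value_dim == conv_dim  # sanity-check the decomposition
    return conv_dim, key_dim
```

With `total_dim=10` and `value_dim=2` this yields `conv_dim=8` and `key_dim=3`; a shape that cannot decompose this way raises instead of being silently mis-split.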

@larryliu0820
Contributor

@Phineas1500 thank you for adding this! Have you tried to export with the xnnpack recipe and test out the .pte file?

Copilot AI review requested due to automatic review settings March 3, 2026 00:02
@Phineas1500
Contributor Author

> @Phineas1500 thank you for adding this! Have you tried to export with the xnnpack recipe and test out the .pte file?

I exported with the XNNPACK recipe for Qwen3.5-0.8B, and I exported 4B in no-backend mode (my computer had memory constraints).

I'm in the process of testing 0.8B. Should have done that before making the PR, and I'll get back to you soon.

Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 20 out of 20 changed files in this pull request and generated 4 comments.

Comments suppressed due to low confidence (1)

examples/models/llama/norm.py:24

  • RMSNorm.init now accepts add_unit_offset, but the docstring’s Args section still only documents dim and eps. Please update the docstring to include add_unit_offset and describe how it changes the scaling (e.g., output * (1 + weight) when enabled), so callers know when to set it.
```python
def __init__(self, dim: int, eps: float = 1e-6, add_unit_offset: bool = False):
    """
    Initialize the RMSNorm normalization layer.

    Args:
        dim (int): The dimension of the input tensor.
        eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

    Attributes:
        eps (float): A small value added to the denominator for numerical stability.
        weight (nn.Parameter): Learnable scaling parameter.

    """
```


Comment on lines +858 to +865
```python
input_pos = kwargs.get("input_pos")
batch_size, seq_len, _ = x.shape
assert (
    batch_size <= self.max_batch_size
), f"batch_size ({batch_size}) exceeds max_batch_size ({self.max_batch_size})"

self._maybe_reset_state(input_pos, batch_size)
```


Copilot AI Mar 3, 2026


AttentionGatedDeltaNet maintains internal conv_state/recurrent_state, but when input_pos is omitted it never resets, so outputs can depend on prior forward calls (state leakage across sequences) when the model is run without kv-cache / without passing input_pos. If this attention is only valid in kv-cache mode, consider asserting input_pos is not None (or alternatively resetting state when input_pos is None) to avoid silently incorrect results.
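One way to make that policy explicit, sketched with a toy stand-in (the method name `_maybe_reset_state` mirrors the excerpt above; the scalar accumulator "state" is purely illustrative, not the real conv/recurrent tensors):

```python
class GatedDeltaNetStateStub:
    """Toy model of the suggestion: require input_pos and reset carried
    state at the start of each sequence (input_pos == 0)."""

    def __init__(self):
        self.recurrent_state = None  # stands in for conv/recurrent state tensors

    def _maybe_reset_state(self, input_pos):
        if input_pos is None:
            # Fail loudly rather than silently reusing stale state.
            raise ValueError("input_pos is required for stateful linear attention")
        if input_pos == 0:
            self.recurrent_state = 0.0  # fresh sequence: clear carried state

    def step(self, x, input_pos=None):
        self._maybe_reset_state(input_pos)
        self.recurrent_state += x  # toy recurrence so leakage would be visible
        return self.recurrent_state
```

Stepping a new sequence with `input_pos=0` makes outputs independent of any prior sequence, and omitting `input_pos` raises instead of leaking state across calls.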

Comment on lines +51 to +59
```python
checkpoint_shards = sorted(set(weight_map.values()))

shard_to_weights = {}
for shard in checkpoint_shards:
    shard_to_weights[shard] = load_file(os.path.join(input_dir, shard))

merged_state_dict = {}
for weight_name, shard in weight_map.items():
    merged_state_dict[weight_name] = shard_to_weights[shard][weight_name]
```

Copilot AI Mar 3, 2026


_load_checkpoint_from_safetensors() loads every shard into shard_to_weights before merging, which can significantly increase peak memory for sharded Qwen3.5 checkpoints. Consider switching to the streaming pattern used in examples/models/qwen2_5/convert_weights.py (group keys per shard, load one shard at a time, then del shard_data) to reduce peak RAM usage.

Suggested change (stream one shard at a time instead of loading all shards up front):

```python
# Group parameter names by shard so we can load each shard only once.
shard_to_weight_names: Dict[str, list[str]] = {}
for weight_name, shard in weight_map.items():
    if shard not in shard_to_weight_names:
        shard_to_weight_names[shard] = []
    shard_to_weight_names[shard].append(weight_name)

merged_state_dict: Dict[str, torch.Tensor] = {}
# Stream shards: load one shard at a time, copy required tensors, then free it.
for shard, weight_names in shard_to_weight_names.items():
    shard_path = os.path.join(input_dir, shard)
    shard_data = load_file(shard_path)
    for weight_name in weight_names:
        merged_state_dict[weight_name] = shard_data[weight_name]
    del shard_data
```

Comment on lines +111 to +124
```python
try:
    new_key = get_mapped_key(normalized_key, _QWEN_3_5_TO_META)
except Exception:
    # Ignore non-text weights and training-only extras (e.g., MTP).
    if (
        key.startswith("mtp.")
        or key.startswith("model.visual.")
        or ".vision_" in key
        or key.startswith("visual.")
    ):
        continue
    # Ignore unsupported keys that are not required by the export model.
    continue
converted_state_dict[new_key] = value
```

Copilot AI Mar 3, 2026


qwen_3_5_to_meta() currently catches the exception from get_mapped_key(...) and then unconditionally continues for all unmapped keys. This makes conversion silently succeed even if a required text weight is missing from _QWEN_3_5_TO_META (e.g., due to a naming change), which can lead to hard-to-debug runtime failures later. Suggestion: explicitly whitelist/skip known non-text prefixes (visual/MTP/etc.), but for any other unexpected key either raise (default) or at least collect and print a summary of skipped keys behind a --verbose/--strict flag.
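A hedged sketch of that shape of API (the table entries, helper names, and `strict` flag below are simplified stand-ins, not the PR's actual `get_mapped_key`/`_QWEN_3_5_TO_META` implementation):

```python
import re

# Simplified stand-ins for the real template-keyed mapping table and helpers.
HF_TO_META = {
    r"model\.embed_tokens\.weight": r"tok_embeddings.weight",
    r"model\.layers\.(\d+)\.self_attn\.q_proj\.weight": r"layers.\1.attention.wq.weight",
}
IGNORED_PREFIXES = ("mtp.", "model.visual.", "visual.")  # known non-text weights


def convert_keys(keys, strict=True):
    """Map HF checkpoint keys to meta-style keys; collect anything unmapped."""
    converted, skipped = {}, []
    for key in keys:
        for pattern, target in HF_TO_META.items():
            if re.fullmatch(pattern, key):
                converted[key] = re.sub(pattern, target, key)
                break
        else:
            if key.startswith(IGNORED_PREFIXES):
                continue  # known non-text weight: safe to drop silently
            skipped.append(key)
    if strict and skipped:
        raise ValueError(f"Unmapped checkpoint keys: {skipped}")
    return converted, skipped
```

Known non-text prefixes are dropped quietly; anything else either raises (`strict=True`) or is returned in `skipped` so the caller can print a summary.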

Comment on lines +101 to +109
```python
# Legacy packed tensors (older checkpoints):
#   in_proj_qkvz -> split into in_proj_qkv and in_proj_z
#   in_proj_ba -> split into in_proj_b and in_proj_a
if normalized_key.endswith(".linear_attn.in_proj_qkvz.weight"):
    pending_qkvz[normalized_key] = value
    continue
if normalized_key.endswith(".linear_attn.in_proj_ba.weight"):
    pending_ba[normalized_key] = value
    continue
```

Copilot AI Mar 3, 2026


The legacy packed tensor handling (*.linear_attn.in_proj_qkvz.weight and *.linear_attn.in_proj_ba.weight splitting) is new behavior but isn’t covered by the added unit test (which only uses the already-split in_proj_qkv/in_proj_z and in_proj_b/in_proj_a keys). Please add a small test case exercising both packed keys and validating the produced in_proj_qkv, in_proj_z, in_proj_b, and in_proj_a outputs (including shape checks).

@Phineas1500
Contributor Author

> @Phineas1500 thank you for adding this! Have you tried to export with the xnnpack recipe and test out the .pte file?

> I exported with the XNNPACK recipe for Qwen3.5-0.8B, and I exported 4B in no-backend mode (my computer had memory constraints).
>
> I'm in the process of testing 0.8B. Should have done that before making the PR, and I'll get back to you soon.

Just ran a forward pass successfully. Now I'm recreating the pte file and testing full generation (initially set max sequence length to 1 to save resources).


Copilot AI review requested due to automatic review settings March 3, 2026 00:56
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 20 out of 20 changed files in this pull request and generated 3 comments.

Comments suppressed due to low confidence (1)

examples/models/llama/norm.py:24

  • RMSNorm.__init__ now accepts add_unit_offset, but the docstring doesn’t describe what this flag does or how it changes the scaling (e.g., using (1 + weight) instead of weight). Please update the docstring/Args section so callers understand the semantics.
```python
def __init__(self, dim: int, eps: float = 1e-6, add_unit_offset: bool = False):
    """
    Initialize the RMSNorm normalization layer.

    Args:
        dim (int): The dimension of the input tensor.
        eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

    Attributes:
        eps (float): A small value added to the denominator for numerical stability.
        weight (nn.Parameter): Learnable scaling parameter.

    """
```


Comment on lines +129 to +135
```python
# Ignore non-language-model keys up front.
if not (
    normalized_key.startswith("model.") or normalized_key.startswith("lm_head.")
):
    if _should_ignore_unmapped_key(key, normalized_key):
        continue
    continue
```

Copilot AI Mar 3, 2026


In qwen_3_5_to_meta, keys that don’t start with model. or lm_head. are currently always skipped (even when they’re not in the explicit ignore list). This can silently drop unexpected checkpoint prefixes and produce a partially-converted/empty state dict without error. Consider raising a ValueError for non-text keys unless _should_ignore_unmapped_key(...) returns true (or at least logging them in non-verbose mode).

Suggested change:

```python
# Ignore non-language-model keys up front, but fail on unexpected prefixes.
if not (
    normalized_key.startswith("model.") or normalized_key.startswith("lm_head.")
):
    if _should_ignore_unmapped_key(key, normalized_key):
        continue
    raise ValueError(
        f"Unexpected non-language-model checkpoint key for Qwen3.5 export: {key}"
    )
```

```python
"model.layers.{}.linear_attn.in_proj_b.weight": "layers.{}.attention.in_proj_b.weight",
"model.layers.{}.linear_attn.in_proj_a.weight": "layers.{}.attention.in_proj_a.weight",
"model.layers.{}.linear_attn.conv1d.weight": "layers.{}.attention.conv1d.weight",
"model.layers.{}.linear_attn.conv1d.bias": "layers.{}.attention.conv1d.bias",
```

Copilot AI Mar 3, 2026


The mapping includes model.layers.*.linear_attn.conv1d.bias -> layers.*.attention.conv1d.bias, but the implemented DeltaNet conv (nn.Conv1d(..., bias=False)) has no bias parameter. If HF checkpoints contain this bias, it will become an unexpected key and be ignored by the loader (strict=False), potentially degrading correctness. Either remove/ignore the conv1d.bias mapping or enable a bias in the model and load it consistently.

Suggested change:

```diff
-    "model.layers.{}.linear_attn.conv1d.bias": "layers.{}.attention.conv1d.bias",
```

Comment on lines +137 to +194
```python
# Legacy packed tensors (older checkpoints):
#   in_proj_qkvz -> split into in_proj_qkv and in_proj_z
#   in_proj_ba -> split into in_proj_b and in_proj_a
if normalized_key.endswith(".linear_attn.in_proj_qkvz.weight"):
    pending_qkvz[normalized_key] = value
    continue
if normalized_key.endswith(".linear_attn.in_proj_ba.weight"):
    pending_ba[normalized_key] = value
    continue

try:
    new_key = get_mapped_key(normalized_key, _QWEN_3_5_TO_META)
except Exception as err:
    if _should_ignore_unmapped_key(key, normalized_key):
        continue
    raise ValueError(
        f"Unexpected checkpoint key not mapped for Qwen3.5 export: {key}"
    ) from err
converted_state_dict[new_key] = value

for key, value in pending_qkvz.items():
    layer_match = re.search(r"model\.layers\.(\d+)\.", key)
    if layer_match is None:
        raise ValueError(f"Failed to parse layer id from key: {key}")
    layer_id = layer_match.group(1)
    out_proj_key = f"layers.{layer_id}.attention.out_proj.weight"
    if out_proj_key not in converted_state_dict:
        raise ValueError(
            f"Cannot split {key}: missing {out_proj_key} to infer value dimension."
        )

    value_dim = converted_state_dict[out_proj_key].shape[1]
    total_dim = value.shape[0]
    conv_dim = total_dim - value_dim
    if conv_dim <= 0 or (conv_dim - value_dim) % 2 != 0:
        raise ValueError(
            f"Invalid packed in_proj_qkvz shape for {key}: {tuple(value.shape)}"
        )
    qkv, z = torch.split(value, [conv_dim, value_dim], dim=0)
    converted_state_dict[f"layers.{layer_id}.attention.in_proj_qkv.weight"] = qkv
    converted_state_dict[f"layers.{layer_id}.attention.in_proj_z.weight"] = z
    print(f"Split legacy packed key {key} -> in_proj_qkv + in_proj_z")

for key, value in pending_ba.items():
    layer_match = re.search(r"model\.layers\.(\d+)\.", key)
    if layer_match is None:
        raise ValueError(f"Failed to parse layer id from key: {key}")
    layer_id = layer_match.group(1)
    if value.shape[0] % 2 != 0:
        raise ValueError(
            f"Invalid packed in_proj_ba shape for {key}: {tuple(value.shape)}"
        )
    half = value.shape[0] // 2
    b, a = torch.split(value, [half, half], dim=0)
    converted_state_dict[f"layers.{layer_id}.attention.in_proj_b.weight"] = b
    converted_state_dict[f"layers.{layer_id}.attention.in_proj_a.weight"] = a
    print(f"Split legacy packed key {key} -> in_proj_b + in_proj_a")
```


Copilot AI Mar 3, 2026


Legacy packed key support (*.linear_attn.in_proj_qkvz.weight and *.linear_attn.in_proj_ba.weight) is new behavior here, but the unit tests don’t cover these split paths. Adding a small test case that includes packed tensors and asserts the expected split keys/shapes would help prevent regressions.
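The requested test could stay small. A sketch of the shape checks over row counts (pure Python lists stand in for tensors here; the real test would build torch tensors and call the converter's actual split paths):

```python
def split_packed_qkvz(rows, value_dim):
    """Mimic the converter's split: first conv_dim rows are packed q/k/v,
    the trailing value_dim rows are the z gate."""
    conv_dim = len(rows) - value_dim
    assert conv_dim > 0 and (conv_dim - value_dim) % 2 == 0, "bad packed layout"
    return rows[:conv_dim], rows[conv_dim:]


def split_packed_ba(rows):
    """Mimic the converter's even split of in_proj_ba into b and a."""
    assert len(rows) % 2 == 0, "bad packed layout"
    half = len(rows) // 2
    return rows[:half], rows[half:]


# Shape assertions mirroring what such a unit test would check.
qkv, z = split_packed_qkvz(list(range(10)), value_dim=2)
assert (len(qkv), len(z)) == (8, 2)

b, a = split_packed_ba(list(range(6)))
assert (len(b), len(a)) == (3, 3)
```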

```yaml
backend:
  xnnpack:
    enabled: True
    extended_ops: True
```
Contributor


Hi @Phineas1500 thanks for putting up this PR! I'm wondering if you were able to successfully export to executorch using this config?

Contributor Author

@Phineas1500 Phineas1500 Mar 3, 2026


Hey! I validated qwen3_5_0_8b with qwen3_5_xnnpack_fp32.yaml, and I exported successfully to .pte. I asked the model what 2+2 was, and it said 4 😂

I also successfully exported qwen3_5_4b with xnnpack not enabled (my MacBook Air isn't powerful enough to fully do it).

I exported qwen3_5_4b with xnnpack successfully on Google Colab with max_seq_length and max_context_length set to 128; the resulting qwen3_5_4b_no_backend_smoke.pte was 15.5 GB, while qwen3_5_4b_xnnpack_fp32_128.pte was 18 GB.

Currently trying to export qwen3_5_4b with xnnpack and q8da4w so I can try running it on a phone. Hopefully should work 🤞

Note, I was using the code from this PR and from #17801

Contributor


Wow, that's great to hear! Thank you

Contributor Author


Tested this exact config for 0.8B, 2B, and 4B; export to .pte was successful, and the models were able to answer questions.

Contributor


Great! Not for this PR, but it would be good to add a CI test for the 0.8B version that covers e2e export, lowering, and running with xnnpack.

Copilot AI review requested due to automatic review settings March 5, 2026 01:28
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 20 out of 20 changed files in this pull request and generated no new comments.



@Phineas1500
Contributor Author

In these commits I addressed the concerns brought up. I haven't yet addressed Jacob's comment about dynamic shapes, but I might make a separate PR for that. I also made corresponding changes to #17801.

Contributor

@lucylq lucylq left a comment


@Phineas1500 thanks so much for your patience, I think the PR is in good shape. I took another pass over it and please take a look at the comments - mostly nits to reduce code bloat / duplication and make this easier to maintain going forwards.


```python
def test_qwen35_full_attention_forward_shape(self):
    torch.manual_seed(0)
    args = ModelArgs(
```
Contributor


nit: could we extract these out into a function, e.g.

```python
def _make_args(self, **kwargs):
    defaults = dict(
        dim=32, n_layers=1, n_heads=4, n_kv_heads=2,
        head_dim=8, hidden_dim=64, max_seq_len=16,
        max_context_len=16,
    )
    defaults.update(kwargs)
    return ModelArgs(**defaults)
```

```python
        return self.k_cache, self.v_cache


@register_attention("qwen3_5_full")
```
Contributor


Do we still need to register qwen3_5_full? Seems like we can just use mha - and tests use a mix of both.

```python
if not (
    normalized_key.startswith("model.") or normalized_key.startswith("lm_head.")
):
    if _should_ignore_unmapped_key(key, normalized_key):
```
Contributor


nit: might be simpler to check the valid ones inline

```python
if not (key.startswith("model.layers.") or key.startswith("model.embed")
        or key.startswith("model.norm") or key.startswith("lm_head")):
    continue
```

Contributor

@lucylq lucylq Mar 5, 2026


Thanks @Phineas1500 - so I think we can remove _should_ignore_unmapped_key, the function definition and the _IGNORED_UNMAPPED_* items

Ah maybe we can't remove _IGNORED_UNMAPPED_SUFFIXES. That's fine then

Contributor Author


Made a commit here and on #17801 addressing your comments!


```python
# Legacy packed tensors (older checkpoints):
#   in_proj_qkvz -> split into in_proj_qkv and in_proj_z
#   in_proj_ba -> split into in_proj_b and in_proj_a
```
Contributor


Is this a real use-case? Are there published checkpoints that use this format, or can we drop it (and add it later if it becomes necessary)?

```python
    return False


def _load_checkpoint_from_safetensors(input_dir: str) -> Dict:
```
Contributor


Not for this PR but this is copied in each of the qwen directories (and potentially others). We should extract this out into a util in checkpoint.py

@Phineas1500
Contributor Author

Phineas1500 commented Mar 5, 2026

I tried addressing your comments @lucylq . Let me know if anything else is needed (also applied the changes to #17801 )

Copilot AI review requested due to automatic review settings March 5, 2026 19:32
Contributor

@lucylq lucylq left a comment


Lgtm, thank you for enabling Qwen3.5 @Phineas1500! And for patiently iterating.

The linter signal is merge-blocking; please run `lintrunner -a` per the CONTRIBUTING guidelines before merging.

Contributor

Copilot AI left a comment


Copilot encountered an error and was unable to review this pull request. You can try again by re-requesting a review.

Copilot AI review requested due to automatic review settings March 5, 2026 21:27
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 20 out of 20 changed files in this pull request and generated 1 comment.



```diff
-        return output * self.weight
+        if self.add_unit_offset:
+            return output * (1.0 + self.weight.float()).type_as(x)
+        return output * self.weight.type_as(x)
```

Copilot AI Mar 5, 2026


This change adds .type_as(x) to the default (non-add_unit_offset) path, which changes behavior for ALL existing models using RMSNorm, not just Qwen3.5. Previously, when x was bfloat16 and self.weight was float32, the multiplication would produce a float32 result. Now the weight is first cast to x's dtype (e.g. bfloat16), keeping the result in that dtype. While this makes the behavior consistent with the add_unit_offset=True branch and may be intentionally fixing dtype propagation, it could subtly change numerical results for existing models using lower-precision dtypes. Please confirm this is intentional and that downstream models have been validated.

Suggested change
return output * self.weight.type_as(x)
return output * self.weight

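The dtype point can be checked directly. This is a minimal standalone sketch (not the actual executorch RMSNorm) showing how `.type_as(x)` changes the result dtype when the weight is float32 and the activations are bfloat16:

```python
import torch

x = torch.ones(3, dtype=torch.bfloat16)
weight = torch.ones(3, dtype=torch.float32)

# Mirror the usual RMSNorm pattern: normalize in float32 for stability,
# then cast the normalized output back to the input dtype.
variance = x.float().pow(2).mean(-1, keepdim=True)
output = (x.float() * torch.rsqrt(variance + 1e-6)).type_as(x)  # bfloat16

before = output * weight            # bf16 * f32 promotes to float32
after = output * weight.type_as(x)  # weight cast to bf16 first, stays bfloat16

assert before.dtype == torch.float32
assert after.dtype == torch.bfloat16
```

So the suggestion above is about which dtype propagates out of the norm, not about the normalization math itself.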
Copilot AI review requested due to automatic review settings March 5, 2026 22:12
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 20 out of 20 changed files in this pull request and generated 2 comments.



Comment on lines +4 to +19
from executorch.examples.models.qwen3_5.convert_weights import convert_weights

__all__ = ["Qwen3_5Model", "convert_weights"]


def __getattr__(name):
    if name == "Qwen3_5Model":
        from executorch.examples.models.llama.model import Llama2Model

        class Qwen3_5Model(Llama2Model):
            def __init__(self, **kwargs):
                super().__init__(**kwargs)

        globals()["Qwen3_5Model"] = Qwen3_5Model
        return Qwen3_5Model
    raise AttributeError(f"module '{__name__}' has no attribute '{name}'")

Copilot AI Mar 5, 2026


This __init__.py uses a lazy __getattr__-based pattern with a dynamically created Qwen3_5Model class, which is inconsistent with every other model package in this codebase (e.g., qwen2_5/__init__.py, qwen3/__init__.py, phi_4_mini/__init__.py, smollm2/__init__.py, gemma/__init__.py, etc.), all of which define the model class at module scope via a direct import. For consistency, consider matching the established pattern:

from executorch.examples.models.llama.model import Llama2Model
from executorch.examples.models.qwen3_5.convert_weights import convert_weights


class Qwen3_5Model(Llama2Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)


__all__ = ["Qwen3_5Model", "convert_weights"]
Suggested change

-from executorch.examples.models.qwen3_5.convert_weights import convert_weights
-
-__all__ = ["Qwen3_5Model", "convert_weights"]
-
-
-def __getattr__(name):
-    if name == "Qwen3_5Model":
-        from executorch.examples.models.llama.model import Llama2Model
-
-        class Qwen3_5Model(Llama2Model):
-            def __init__(self, **kwargs):
-                super().__init__(**kwargs)
-
-        globals()["Qwen3_5Model"] = Qwen3_5Model
-        return Qwen3_5Model
-    raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
+from executorch.examples.models.llama.model import Llama2Model
+from executorch.examples.models.qwen3_5.convert_weights import convert_weights
+
+
+class Qwen3_5Model(Llama2Model):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+
+__all__ = ["Qwen3_5Model", "convert_weights"]

Comment on lines +7 to +20
 from typing import TYPE_CHECKING

-__all__ = [
-    Llama2Model,
-]
+if TYPE_CHECKING:
+    from .model import Llama2Model
+
+__all__ = ["Llama2Model"]
+
+
+def __getattr__(name):
+    if name == "Llama2Model":
+        from .model import Llama2Model
+
+        return Llama2Model
+    raise AttributeError(f"module '{__name__}' has no attribute '{name}'")

Copilot AI Mar 5, 2026


This changes the llama __init__.py from a direct import (from .model import Llama2Model) to a lazy __getattr__ pattern. All other model packages that import Llama2Model (including qwen2_5, qwen3, phi_4_mini, smollm2, gemma, granite, etc.) still use a direct top-level import from executorch.examples.models.llama.model. If there's no circular import or startup performance issue driving this change, it adds unnecessary complexity and inconsistency. Consider keeping the direct import pattern consistent with the rest of the codebase.

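For context, the lazy pattern above relies on PEP 562 module-level `__getattr__`. A self-contained toy (the module and class names here are made up for illustration) shows that the deferred work runs only on first attribute access and is cached afterwards:

```python
import sys
import types

# Build a throwaway module in memory instead of a real file on disk.
mod = types.ModuleType("lazy_demo")
mod.__dict__["_loaded"] = []


def _module_getattr(name):
    if name == "Heavy":
        mod._loaded.append(name)  # stands in for an expensive import

        class Heavy:  # hypothetical class the module exposes lazily
            pass

        mod.Heavy = Heavy  # cache so __getattr__ is not hit again
        return Heavy
    raise AttributeError(f"module 'lazy_demo' has no attribute '{name}'")


# PEP 562: Python falls back to __getattr__ found in the module's __dict__.
mod.__getattr__ = _module_getattr
sys.modules["lazy_demo"] = mod

import lazy_demo

assert lazy_demo._loaded == []          # nothing "imported" yet
h = lazy_demo.Heavy                     # triggers __getattr__ once
assert lazy_demo._loaded == ["Heavy"]
assert lazy_demo.Heavy is h             # cached; no second trigger
```

Whether that indirection is worth it in llama's `__init__.py` is exactly the trade-off Copilot raises: it only pays off if there is a circular-import or import-cost problem to solve.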
@Phineas1500
Contributor Author

@pytorchbot unlabel "release notes: none"
@pytorchbot label "release notes: examples"

@pytorch-bot

pytorch-bot bot commented Mar 5, 2026

❌ 🤖 pytorchbot command failed:

@pytorchbot: error: argument command: invalid choice: 'unlabel' (choose from 'merge', 'revert', 'rebase', 'label', 'drci', 'cherry-pick')

usage: @pytorchbot [-h] {merge,revert,rebase,label,drci,cherry-pick} ...

Try @pytorchbot --help for more info.

@Phineas1500
Contributor Author

@pytorchbot label "release notes: examples"

@pytorch-bot pytorch-bot bot added the release notes: examples Changes to any of our example LLMs integrations, such as Llama3 and Llava label Mar 5, 2026
@Phineas1500
Contributor Author

Phineas1500 commented Mar 6, 2026

The macOS failures that persist seem unrelated to my Qwen3.5 additions (and the other failing jobs don't appear to be code-related either). The two tests seem to be failing on main as well. What do you advise I do, @lucylq?

Also, pull / android / run-emulator has been stuck there for a while.

@lucylq lucylq merged commit 9d413ac into pytorch:main Mar 6, 2026
155 of 164 checks passed
@lucylq
Contributor

lucylq commented Mar 6, 2026

@Phineas1500 yeah, I think the macOS tests are pre-existing; we're working on a fix 😅. Just merged. Can you rebase #17801? I will take a look at it.

@Phineas1500
Contributor Author

> @Phineas1500 yeah, I think the macOS tests are pre-existing; we're working on a fix 😅. Just merged. Can you rebase #17801? I will take a look at it.

Sure! Gonna do that right now

@larryliu0820
Contributor

Thank you @Phineas1500 for pushing this through!


Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. module: examples Issues related to demos under examples/ release notes: examples Changes to any of our example LLMs integrations, such as Llama3 and Llava release notes: none Do not include this in the release notes
